{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ced17a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "from ark_nlp.nn import BertConfig as ModuleConfig\n",
    "from ark_nlp.dataset import TwinTowersSentenceClassificationDataset as Dataset\n",
    "from ark_nlp.processor.tokenizer.transfomer import SentenceTokenizer as Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfcea417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the ATEC data files; the test split is used as the dev set\n",
    "atec_data_dir = '../data/source_datasets/ATEC'\n",
    "train_data_path = atec_data_dir + '/ATEC.train.data'\n",
    "dev_data_path = atec_data_dir + '/ATEC.test.data'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccea726a",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 一、数据读入与处理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d5c3337",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### 1. 数据读入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e651e5f-8535-43c2-893b-10f3275517ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both files are read as headerless, tab-separated (text_a, text_b, label) rows\n",
    "atec_read_options = dict(sep='\\t', header=None, names=['text_a', 'text_b', 'label'])\n",
    "\n",
    "train_data_df = pd.read_csv(train_data_path, **atec_read_options)\n",
    "dev_data_df = pd.read_csv(dev_data_path, **atec_read_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90876062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap each DataFrame in the twin-tower sentence-pair dataset (train first, then dev)\n",
    "cosent_train_dataset, cosent_dev_dataset = (\n",
    "    Dataset(frame) for frame in (train_data_df, dev_data_df)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e061890a",
   "metadata": {},
   "source": [
    "#### 2. 词典创建和生成分词器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be116454",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer with the bert-base-chinese vocabulary\n",
    "tokenizer = Tokenizer(\n",
    "    vocab='bert-base-chinese',\n",
    "    max_seq_len=64\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d6c3b3d",
   "metadata": {},
   "source": [
    "#### 3. ID化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "566dd6b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert both splits to ids with the shared tokenizer (train first, then dev)\n",
    "for _dataset in (cosent_train_dataset, cosent_dev_dataset):\n",
    "    _dataset.convert_to_ids(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "981b4160",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "### 二、模型构建"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72753ee8",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### 1. 模型参数设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527535d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertConfig\n",
    "\n",
    "# One label per entry of the training set's cat2id mapping\n",
    "num_labels = len(cosent_train_dataset.cat2id)\n",
    "bert_config = BertConfig.from_pretrained('bert-base-chinese', num_labels=num_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77d8b580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release cached GPU memory that PyTorch holds but is not currently using\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "700d7752",
   "metadata": {},
   "source": [
    "#### 2. 模型创建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f380f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from torch import nn\n",
    "from transformers import BertModel\n",
    "from ark_nlp.nn import Bert\n",
    "\n",
    "\n",
    "class CoSENT(Bert):\n",
    "    \"\"\"\n",
    "    CoSENT sentence-embedding model: one BERT encoder is applied to both\n",
    "    sentences of a pair and trained with the CoSENT pairwise ranking loss.\n",
    "\n",
    "    Args:\n",
    "        config:\n",
    "            Model configuration object.\n",
    "        encoder_trained (:obj:`bool`, optional, defaults to True):\n",
    "            Whether the BERT parameters are trainable.\n",
    "        pooling (:obj:`str`, optional, defaults to \"last_avg\"):\n",
    "            Pooling applied to the BERT output; documented options are\n",
    "            [\"cls\", \"cls_with_pooler\", \"first_last_avg\", \"last_avg\", \"last_2_avg\"].\n",
    "        dropout (:obj:`float` or :obj:`None`, optional, defaults to None):\n",
    "            Dropout probability for the pooled embedding; None means 0.1.\n",
    "        output_emb_size (:obj:`int`, optional, defaults to 0):\n",
    "            Size of the output embedding; 0 adds no projection layer.\n",
    "        scale (:obj:`float`, optional, defaults to 20):\n",
    "            Multiplier applied to the cosine similarities inside the loss.\n",
    "\n",
    "    Reference:\n",
    "        [1] https://kexue.fm/archives/8847\n",
    "        [2] https://github.com/bojone/CoSENT\n",
    "    \"\"\"  # noqa: ignore flake8\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        config,\n",
    "        encoder_trained=True,\n",
    "        pooling='last_avg',\n",
    "        dropout=None,\n",
    "        output_emb_size=0,\n",
    "        scale=20\n",
    "    ):\n",
    "\n",
    "        super(CoSENT, self).__init__(config)\n",
    "\n",
    "        self.bert = BertModel(config)\n",
    "        self.pooling = pooling\n",
    "        self.scale = scale\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)\n",
    "\n",
    "        # When output_emb_size > 0 a linear layer projects the pooled vector to\n",
    "        # that size; 256 is the recommended trade-off between recall\n",
    "        # performance and efficiency\n",
    "        self.output_emb_size = output_emb_size\n",
    "        if self.output_emb_size > 0:\n",
    "            self.emb_reduce_linear = nn.Linear(\n",
    "                config.hidden_size,\n",
    "                self.output_emb_size\n",
    "            )\n",
    "            torch.nn.init.trunc_normal_(\n",
    "                self.emb_reduce_linear.weight,\n",
    "                std=0.02\n",
    "            )\n",
    "\n",
    "        # Freeze or unfreeze the whole encoder in one go\n",
    "        for param in self.bert.parameters():\n",
    "            param.requires_grad = encoder_trained\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def get_pooled_embedding(\n",
    "        self,\n",
    "        input_ids,\n",
    "        token_type_ids=None,\n",
    "        position_ids=None,\n",
    "        attention_mask=None\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Encode one batch of sentences into L2-normalised embeddings.\n",
    "\n",
    "        Runs BERT, pools its output through `get_encoder_feature`, projects\n",
    "        to `output_emb_size` when that is positive, applies dropout and\n",
    "        normalises along the last dimension.\n",
    "        \"\"\"\n",
    "        outputs = self.bert(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            token_type_ids=token_type_ids,\n",
    "            position_ids=position_ids,\n",
    "            return_dict=True,\n",
    "            output_hidden_states=True\n",
    "        )\n",
    "\n",
    "        encoder_feature = self.get_encoder_feature(\n",
    "            outputs,\n",
    "            attention_mask\n",
    "        )\n",
    "\n",
    "        if self.output_emb_size > 0:\n",
    "            encoder_feature = self.emb_reduce_linear(encoder_feature)\n",
    "\n",
    "        encoder_feature = self.dropout(encoder_feature)\n",
    "        out = F.normalize(encoder_feature, p=2, dim=-1, eps=1e-8)\n",
    "\n",
    "        return out\n",
    "\n",
    "    def _encode_pair(\n",
    "        self,\n",
    "        input_ids_a,\n",
    "        input_ids_b,\n",
    "        token_type_ids_a,\n",
    "        position_ids_a,\n",
    "        attention_mask_a,\n",
    "        token_type_ids_b,\n",
    "        position_ids_b,\n",
    "        attention_mask_b\n",
    "    ):\n",
    "        \"\"\"Return the pooled embeddings of side a and side b, encoded in that order.\"\"\"\n",
    "        embedding_a = self.get_pooled_embedding(\n",
    "            input_ids_a,\n",
    "            token_type_ids_a,\n",
    "            position_ids_a,\n",
    "            attention_mask_a\n",
    "        )\n",
    "        embedding_b = self.get_pooled_embedding(\n",
    "            input_ids_b,\n",
    "            token_type_ids_b,\n",
    "            position_ids_b,\n",
    "            attention_mask_b\n",
    "        )\n",
    "\n",
    "        return embedding_a, embedding_b\n",
    "\n",
    "    def cosine_sim(\n",
    "        self,\n",
    "        input_ids_a,\n",
    "        input_ids_b,\n",
    "        token_type_ids_a=None,\n",
    "        position_ids_ids_a=None,\n",
    "        attention_mask_a=None,\n",
    "        token_type_ids_b=None,\n",
    "        position_ids_b=None,\n",
    "        attention_mask_b=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Cosine similarity of each (a, b) sentence pair in the batch.\n",
    "\n",
    "        Extra keyword arguments (for example `label_ids`) are accepted and\n",
    "        ignored so a full input dict can be splatted in. The parameter name\n",
    "        `position_ids_ids_a` is kept as is for backward compatibility.\n",
    "        \"\"\"\n",
    "        embedding_a, embedding_b = self._encode_pair(\n",
    "            input_ids_a,\n",
    "            input_ids_b,\n",
    "            token_type_ids_a,\n",
    "            position_ids_ids_a,\n",
    "            attention_mask_a,\n",
    "            token_type_ids_b,\n",
    "            position_ids_b,\n",
    "            attention_mask_b\n",
    "        )\n",
    "\n",
    "        # Both embeddings are L2-normalised, so their dot product is the cosine\n",
    "        return torch.sum(embedding_a * embedding_b, dim=-1)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids_a,\n",
    "        input_ids_b,\n",
    "        token_type_ids_a=None,\n",
    "        position_ids_ids_a=None,\n",
    "        attention_mask_a=None,\n",
    "        token_type_ids_b=None,\n",
    "        position_ids_b=None,\n",
    "        attention_mask_b=None,\n",
    "        label_ids=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Compute the CoSENT loss for a batch of labelled sentence pairs.\n",
    "\n",
    "        Returns:\n",
    "            A tuple `(scores, loss)`. `scores` is the flattened matrix of\n",
    "            masked pairwise differences of the scaled similarities with a\n",
    "            leading zero prepended; `loss` is its logsumexp.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `label_ids` is None, since the loss is defined\n",
    "                only relative to the labels.\n",
    "        \"\"\"\n",
    "        if label_ids is None:\n",
    "            raise ValueError(\n",
    "                'label_ids is required to compute the CoSENT loss; '\n",
    "                'call cosine_sim() for inference'\n",
    "            )\n",
    "\n",
    "        cls_embedding_a, cls_embedding_b = self._encode_pair(\n",
    "            input_ids_a,\n",
    "            input_ids_b,\n",
    "            token_type_ids_a,\n",
    "            position_ids_ids_a,\n",
    "            attention_mask_a,\n",
    "            token_type_ids_b,\n",
    "            position_ids_b,\n",
    "            attention_mask_b\n",
    "        )\n",
    "\n",
    "        cosine_sim = torch.sum(cls_embedding_a * cls_embedding_b, dim=1) * self.scale\n",
    "        # Entry [i, j] is the scaled similarity of pair i minus that of pair j\n",
    "        cosine_sim = cosine_sim[:, None] - cosine_sim[None, :]\n",
    "\n",
    "        # Only entries with label_i < label_j contribute to the loss\n",
    "        labels = label_ids[:, None] < label_ids[None, :]\n",
    "        labels = labels.long()\n",
    "\n",
    "        # Push every other entry to a huge negative value so exp() of it vanishes\n",
    "        cosine_sim = cosine_sim - (1 - labels) * 1e12\n",
    "        # The prepended zero supplies the \"1 +\" term inside the log; new_zeros\n",
    "        # keeps it on the same device and dtype as the scores\n",
    "        cosine_sim = torch.cat((cosine_sim.new_zeros(1), cosine_sim.view(-1)), dim=0)\n",
    "        loss = torch.logsumexp(cosine_sim, dim=0)\n",
    "\n",
    "        return cosine_sim, loss\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e630530b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the CoSENT module from the bert-base-chinese checkpoint and the config above\n",
    "dl_module = CoSENT.from_pretrained('bert-base-chinese', config=bert_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13e3c8ac",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "### 三、任务构建"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31d1f76c",
   "metadata": {},
   "source": [
    "#### 1. 任务参数和必要部件设定"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943bf64c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule: number of epochs and mini-batch size\n",
    "num_epoches, batch_size = 5, 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74641ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters whose name contains one of these substrings get no weight decay\n",
    "no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
    "\n",
    "\n",
    "def _is_no_decay(param_name):\n",
    "    \"\"\"Return True when `param_name` contains an entry of `no_decay`.\"\"\"\n",
    "    return any(nd in param_name for nd in no_decay)\n",
    "\n",
    "\n",
    "# Pooler parameters are left out of the optimizer altogether\n",
    "param_optimizer = [\n",
    "    (name, param) for name, param in dl_module.named_parameters()\n",
    "    if 'pooler' not in name\n",
    "]\n",
    "optimizer_grouped_parameters = [\n",
    "    {\n",
    "        'params': [param for name, param in param_optimizer if not _is_no_decay(name)],\n",
    "        'weight_decay': 0.01\n",
    "    },\n",
    "    {\n",
    "        'params': [param for name, param in param_optimizer if _is_no_decay(name)],\n",
    "        'weight_decay': 0.0\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd5a9361",
   "metadata": {},
   "source": [
    "#### 2. 任务创建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bc04465",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "from ark_nlp.factory.task.base._sequence_classification import SequenceClassificationTask\n",
    "\n",
    "\n",
    "class CoSENTTask(SequenceClassificationTask):\n",
    "    \"\"\"\n",
    "    Task for text matching with the CoSENT model.\n",
    "\n",
    "    Evaluation prints the Spearman correlation between the gold labels and\n",
    "    the predicted cosine similarities, together with the mean loss per step.\n",
    "\n",
    "    Args:\n",
    "        module: the deep-learning model\n",
    "        optimizer: name of, or instance of, the optimizer used for training\n",
    "        loss_function: name of, or instance of, the loss function used for training\n",
    "        class_num (:obj:`int` or :obj:`None`, optional, defaults to None): number of labels\n",
    "        scheduler (:obj:`class`, optional, defaults to None): scheduler object\n",
    "        n_gpu (:obj:`int`, optional, defaults to 1): number of GPUs\n",
    "        device (:obj:`class`, optional, defaults to None): torch.device object; when None, GPU availability is detected automatically\n",
    "        cuda_device (:obj:`int`, optional, defaults to 0): GPU index, used to pick the device when `device` is None\n",
    "        ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA weighting coefficient\n",
    "        **kwargs (optional): other optional arguments\n",
    "    \"\"\"  # noqa: ignore flake8\n",
    "\n",
    "    def _on_evaluate_begin_record(self, **kwargs):\n",
    "        \"\"\"Reset the running totals and the per-batch buffers before an evaluation pass.\"\"\"\n",
    "\n",
    "        self.evaluate_logs['eval_loss'] = 0\n",
    "        self.evaluate_logs['eval_step'] = 0\n",
    "        self.evaluate_logs['eval_example'] = 0\n",
    "\n",
    "        self.evaluate_logs['labels'] = []\n",
    "        self.evaluate_logs['eval_sim'] = []\n",
    "\n",
    "    def _on_evaluate_step_end(self, inputs, outputs, **kwargs):\n",
    "        \"\"\"Accumulate the loss, and the similarities and labels when labels are present.\"\"\"\n",
    "\n",
    "        with torch.no_grad():\n",
    "            # compute loss\n",
    "            _, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)\n",
    "            self.evaluate_logs['eval_loss'] += loss.item()\n",
    "\n",
    "            if 'label_ids' in inputs:\n",
    "                # A second encoder pass: the module's forward output does not\n",
    "                # expose the per-pair similarities directly\n",
    "                cosine_sim = self.module.cosine_sim(**inputs).cpu().numpy()\n",
    "                self.evaluate_logs['eval_sim'].append(cosine_sim)\n",
    "                self.evaluate_logs['labels'].append(inputs['label_ids'].cpu().numpy())\n",
    "\n",
    "        # Count examples from the batch itself. The first output of\n",
    "        # CoSENT.forward is a flattened pairwise score vector, so its length\n",
    "        # is not the batch size (presumably _get_evaluate_loss passes that\n",
    "        # output through unchanged - TODO confirm)\n",
    "        self.evaluate_logs['eval_example'] += inputs['input_ids_a'].shape[0]\n",
    "        self.evaluate_logs['eval_step'] += 1\n",
    "\n",
    "    def _on_evaluate_epoch_end(\n",
    "        self,\n",
    "        validation_data,\n",
    "        epoch=1,\n",
    "        is_evaluate_print=True,\n",
    "        **kwargs\n",
    "    ):\n",
    "        \"\"\"Print the evaluation summary when `is_evaluate_print` is true.\"\"\"\n",
    "\n",
    "        if not is_evaluate_print:\n",
    "            return\n",
    "\n",
    "        # Guard against an evaluation pass that ran no steps\n",
    "        mean_loss = self.evaluate_logs['eval_loss'] / max(self.evaluate_logs['eval_step'], 1)\n",
    "\n",
    "        # The 'labels' key always exists after _on_evaluate_begin_record, so\n",
    "        # test whether any labels were actually collected\n",
    "        if self.evaluate_logs.get('labels'):\n",
    "            _sims = np.concatenate(self.evaluate_logs['eval_sim'], axis=0)\n",
    "            _labels = np.concatenate(self.evaluate_logs['labels'], axis=0)\n",
    "            spearman_corr = stats.spearmanr(_labels, _sims).correlation\n",
    "            print('evaluate spearman corr is:{:.4f}, evaluate loss is:{:.6f}'.format(\n",
    "                spearman_corr,\n",
    "                mean_loss\n",
    "                )\n",
    "            )\n",
    "        else:\n",
    "            print('evaluate loss is:{:.6f}'.format(mean_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dfc61d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positional arguments follow the documented order: module, optimizer, loss_function.\n",
    "# loss_function is None; CoSENT.forward already returns the loss.\n",
    "model = CoSENTTask(\n",
    "    dl_module,\n",
    "    'adamw',\n",
    "    None,\n",
    "    cuda_device=0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35c96cf8",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### 3. 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e3e9ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_options = dict(\n",
    "    lr=2e-5,\n",
    "    epochs=num_epoches,\n",
    "    batch_size=batch_size,\n",
    "    params=optimizer_grouped_parameters\n",
    ")\n",
    "\n",
    "# First argument is the training dataset, second the dev dataset\n",
    "model.fit(cosent_train_dataset, cosent_dev_dataset, **fit_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b27e57b",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "### 四、模型验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0800c2a3-2b57-435a-87d7-cfafbbc69fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from ark_nlp.factory.predictor import SequenceClassificationPredictor\n",
    "\n",
    "\n",
    "class CoSENTPredictor(SequenceClassificationPredictor):\n",
    "    \"\"\"\n",
    "    Predictor for the CoSENT model.\n",
    "\n",
    "    Args:\n",
    "        module: the deep-learning model\n",
    "        tokenizer: the tokenizer\n",
    "        cat2id (:obj:`dict`): label mapping\n",
    "    \"\"\"  # noqa: ignore flake8\n",
    "\n",
    "    def _get_input_ids(\n",
    "        self,\n",
    "        text_a,\n",
    "        text_b\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Convert a sentence pair to model features according to the tokenizer type.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the tokenizer type is not one of the three handled here.\n",
    "        \"\"\"\n",
    "        # NOTE(review): only the 'transfomer' converter is defined in this class;\n",
    "        # the other two must come from the base class - confirm they accept a pair.\n",
    "        # The 'transfomer' spelling follows the library's module name.\n",
    "        if self.tokenizer.tokenizer_type == 'vanilla':\n",
    "            return self._convert_to_vanilla_ids(text_a, text_b)\n",
    "        elif self.tokenizer.tokenizer_type == 'transfomer':\n",
    "            return self._convert_to_transfomer_ids(text_a, text_b)\n",
    "        elif self.tokenizer.tokenizer_type == 'customized':\n",
    "            return self._convert_to_customized_ids(text_a, text_b)\n",
    "        else:\n",
    "            raise ValueError(\"The tokenizer type does not exist\")\n",
    "\n",
    "    def _convert_to_transfomer_ids(\n",
    "        self,\n",
    "        text_a,\n",
    "        text_b\n",
    "    ):\n",
    "        \"\"\"Tokenize each sentence on its own and collect the per-side model inputs.\"\"\"\n",
    "        # sequence_to_ids is unpacked as (input ids, attention mask, segment ids)\n",
    "        input_ids_a, input_mask_a, segment_ids_a = self.tokenizer.sequence_to_ids(text_a)\n",
    "        input_ids_b, input_mask_b, segment_ids_b = self.tokenizer.sequence_to_ids(text_b)\n",
    "\n",
    "        features = {\n",
    "            'input_ids_a': input_ids_a,\n",
    "            'attention_mask_a': input_mask_a,\n",
    "            'token_type_ids_a': segment_ids_a,\n",
    "            'input_ids_b': input_ids_b,\n",
    "            'attention_mask_b': input_mask_b,\n",
    "            'token_type_ids_b': segment_ids_b\n",
    "        }\n",
    "\n",
    "        return features\n",
    "\n",
    "    def predict_one_sample(\n",
    "        self,\n",
    "        text,\n",
    "        topk=None,\n",
    "        threshold=0.5,\n",
    "        return_label_name=True,\n",
    "        return_proba=False\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Score a single sentence pair.\n",
    "\n",
    "        Args:\n",
    "            text: pair `(text_a, text_b)` of sentences to compare\n",
    "            topk: accepted for interface compatibility; not used\n",
    "            threshold: cut-off handed to `_threshold`; None returns the raw similarity\n",
    "            return_label_name: map the thresholded prediction through `id2cat`\n",
    "            return_proba: also return the similarity next to the prediction\n",
    "\n",
    "        Returns:\n",
    "            The cosine similarity when `threshold` is None; otherwise the\n",
    "            prediction, or `[prediction, similarity]` when `return_proba` is set.\n",
    "        \"\"\"\n",
    "        text_a, text_b = text\n",
    "        features = self._get_input_ids(text_a, text_b)\n",
    "        self.module.eval()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            inputs = self._get_module_one_sample_inputs(features)\n",
    "            logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
    "\n",
    "        # Only the first similarity of the returned array is used\n",
    "        _proba = logits[0]\n",
    "\n",
    "        if threshold is None:\n",
    "            return _proba\n",
    "\n",
    "        _pred = self._threshold(_proba, threshold)\n",
    "        if return_label_name:\n",
    "            _pred = self.id2cat[_pred]\n",
    "\n",
    "        if return_proba:\n",
    "            return [_pred, _proba]\n",
    "\n",
    "        return _pred\n",
    "\n",
    "    def predict_batch(\n",
    "        self,\n",
    "        test_data,\n",
    "        batch_size=16,\n",
    "        shuffle=False\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Score every sentence pair of a dataset.\n",
    "\n",
    "        Args:\n",
    "            test_data: dataset to score; its `dataset_cols` become `self.inputs_cols`\n",
    "            batch_size: number of pairs per DataLoader batch\n",
    "            shuffle: passed to the DataLoader\n",
    "\n",
    "        Returns:\n",
    "            A list with one cosine similarity per pair, in DataLoader order.\n",
    "        \"\"\"\n",
    "        self.inputs_cols = test_data.dataset_cols\n",
    "\n",
    "        preds = []\n",
    "\n",
    "        self.module.eval()\n",
    "        generator = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for inputs in generator:\n",
    "                inputs = self._get_module_batch_inputs(inputs)\n",
    "\n",
    "                logits = self.module.cosine_sim(**inputs).cpu().numpy()\n",
    "\n",
    "                preds.extend(logits)\n",
    "\n",
    "        return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f416a92-4e5d-4f78-91f1-b42f9db71b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments follow the documented order: module, tokenizer, cat2id\n",
    "cosent_predictor_instance = CoSENTPredictor(\n",
    "    model.module,\n",
    "    tokenizer,\n",
    "    cosent_train_dataset.cat2id\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55fe09db-9e27-4300-b57f-df6603643e50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# threshold=None returns the raw cosine similarity instead of a label\n",
    "sentence_pair = [\n",
    "    '怎么花呗不可以用来生活缴费了呀',\n",
    "    '怎么我的花呗不能付电费了'\n",
    "]\n",
    "cosent_predictor_instance.predict_one_sample(sentence_pair, threshold=None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
